#!/usr/bin/env python3

"""
    chassisd
    Module information update daemon for SONiC
    This daemon periodically collects information about all chassis modules and writes it to the state DB.
    The loop interval is CHASSIS_INFO_UPDATE_PERIOD_SECS seconds.
"""

try:
    import os
    import re
    import signal
    import subprocess
    import sys
    import threading
    import time
    import json
    import glob
    from datetime import datetime, timezone

    from sonic_py_common import daemon_base, logger, device_info
    from sonic_py_common.task_base import ProcessTaskBase

    # If unit testing is occurring, mock swsscommon and module_base
    if os.getenv("CHASSISD_UNIT_TESTING") == "1":
        from tests import mock_swsscommon as swsscommon
        from tests.mock_module_base import ModuleBase
    else:
        from swsscommon import swsscommon
        from sonic_platform_base.module_base import ModuleBase
except ImportError as e:
    raise ImportError(str(e) + " - required module not found")

#
# Constants ====================================================================
#

# Maps each signal number to its name (e.g. signal.SIGTERM -> 'SIGTERM').
# Names containing '_' (such as SIG_DFL / SIG_IGN) are excluded; when several
# names share one number, the name that comes last in dir(signal) wins.
SIGNALS_TO_NAMES_DICT = dict((getattr(signal, n), n)
                             for n in dir(signal) if n.startswith('SIG') and '_' not in n)

SYSLOG_IDENTIFIER = "chassisd"

# CONFIG_DB table holding per-module configuration (read for admin_status)
CHASSIS_CFG_TABLE = 'CHASSIS_MODULE'

# STATE_DB chassis summary table; the 'CHASSIS 1' key carries the module count
CHASSIS_INFO_TABLE = 'CHASSIS_TABLE'
CHASSIS_INFO_KEY_TEMPLATE = 'CHASSIS {}'
CHASSIS_INFO_CARD_NUM_FIELD = 'module_num'

# STATE_DB per-module table (keyed by module name) and its fields
CHASSIS_MODULE_INFO_TABLE = 'CHASSIS_MODULE_TABLE'
CHASSIS_MODULE_INFO_KEY_TEMPLATE = 'CHASSIS_MODULE {}'
CHASSIS_MODULE_INFO_NAME_FIELD = 'name'
CHASSIS_MODULE_INFO_DESC_FIELD = 'desc'
CHASSIS_MODULE_INFO_SLOT_FIELD = 'slot'
CHASSIS_MODULE_INFO_OPERSTATUS_FIELD = 'oper_status'
CHASSIS_MODULE_INFO_NUM_ASICS_FIELD = 'num_asics'
# Key of the in-memory module info dict; only its length is written to the DB (as num_asics)
CHASSIS_MODULE_INFO_ASICS = 'asics'
CHASSIS_MODULE_INFO_SERIAL_FIELD = 'serial'

# CHASSIS_STATE_DB ASIC tables: the supervisor uses the fabric ASIC table,
# every other card uses CHASSIS_ASIC_TABLE
CHASSIS_ASIC_INFO_TABLE = 'CHASSIS_ASIC_TABLE'
CHASSIS_FABRIC_ASIC_INFO_TABLE = 'CHASSIS_FABRIC_ASIC_TABLE'
CHASSIS_ASIC = 'asic'
CHASSIS_ASIC_PCI_ADDRESS_FIELD = 'asic_pci_address'
CHASSIS_ASIC_ID_IN_MODULE_FIELD = 'asic_id_in_module'

# STATE_DB midplane reachability table, keyed by module name
CHASSIS_MIDPLANE_INFO_TABLE = 'CHASSIS_MIDPLANE_TABLE'
CHASSIS_MIDPLANE_INFO_KEY_TEMPLATE = 'CHASSIS_MIDPLANE {}'
CHASSIS_MIDPLANE_INFO_NAME_FIELD = 'name'
CHASSIS_MIDPLANE_INFO_IP_FIELD = 'ip_address'
CHASSIS_MIDPLANE_INFO_ACCESS_FIELD = 'access'

# Same table name as CHASSIS_MODULE_INFO_TABLE, but this one is opened in CHASSIS_STATE_DB
CHASSIS_MODULE_HOSTNAME_TABLE = 'CHASSIS_MODULE_TABLE'
CHASSIS_MODULE_INFO_HOSTNAME_FIELD = 'hostname'

# CHASSIS_STATE_DB table tracking expected module reboots
CHASSIS_MODULE_REBOOT_INFO_TABLE = 'CHASSIS_MODULE_REBOOT_INFO_TABLE'
CHASSIS_MODULE_REBOOT_TIMESTAMP_FIELD = 'timestamp'
CHASSIS_MODULE_REBOOT_REBOOT_FIELD = 'reboot'
# Reboot timeouts in seconds; overridable via "linecard_reboot_timeout" in
# platform_env.conf and "dpu_reboot_timeout" in platform.json respectively
DEFAULT_LINECARD_REBOOT_TIMEOUT = 180
DEFAULT_DPU_REBOOT_TIMEOUT = 360
PLATFORM_ENV_CONF_FILE = "/usr/share/sonic/platform/platform_env.conf"
PLATFORM_JSON_FILE = "/usr/share/sonic/platform/platform.json"

# Main loop period (see module docstring)
CHASSIS_INFO_UPDATE_PERIOD_SECS = 10
CHASSIS_DB_CLEANUP_MODULE_DOWN_PERIOD = 30 # Minutes

# Process exit codes; CHASSIS_LOAD_ERROR is used when the chassis cannot be loaded.
# NOTE(review): CHASSIS_NOT_SUPPORTED is not referenced in the visible code.
CHASSIS_LOAD_ERROR = 1
CHASSIS_NOT_SUPPORTED = 2

# NOTE(review): not referenced in the visible code; presumably a swsscommon
# Select timeout in milliseconds — confirm against its user.
SELECT_TIMEOUT = 1000

# Fallback values used when a platform API is unavailable
NOT_AVAILABLE = 'N/A'
INVALID_SLOT = ModuleBase.MODULE_INVALID_SLOT
INVALID_MODULE_INDEX = -1
INVALID_IP = '0.0.0.0'

# CONFIG_DB admin status field, and the values passed to Module.set_admin_state()
CHASSIS_MODULE_ADMIN_STATUS = 'admin_status'
MODULE_ADMIN_DOWN = 0
MODULE_ADMIN_UP = 1
# NOTE(review): not referenced in the visible code; presumably the module
# reboot-cause history directory and how many history files to keep.
MODULE_REBOOT_CAUSE_DIR = "/host/reboot-cause/module/"
MAX_HISTORY_FILES = 10

# This daemon should return non-zero exit code so that supervisord will
# restart it automatically.
exit_code = 0

#
# Helper functions =============================================================
#

# Try to get information from the platform API; return a default value if it raises NotImplementedError


def try_get(callback, *args, **kwargs):
    """
    Call a platform API and fall back to a default value when it is unavailable.

    Only the positional arguments are forwarded to ``callback``; the only
    keyword argument consulted is ``default``.
    :param callback: Callable to invoke
    :param args: Positional arguments forwarded to callback
    :param kwargs: May contain ``default`` (NOT_AVAILABLE when omitted)
    :return: The callback's result, or the default when the callback raises
             NotImplementedError or returns None
    """
    fallback = kwargs.get('default', NOT_AVAILABLE)
    try:
        result = callback(*args)
    except NotImplementedError:
        return fallback
    return fallback if result is None else result

def get_chassis():
    """
    Load the platform chassis object.

    Exits the process with CHASSIS_LOAD_ERROR when the platform package cannot
    be imported or the chassis cannot be instantiated.
    :return: The object returned by sonic_platform.platform.Platform().get_chassis()
    """
    try:
        import sonic_platform.platform
        return sonic_platform.platform.Platform().get_chassis()
    except Exception as e:
        # This is a module-level function, so there is no ``self`` to log
        # through; a standalone syslog logger is used instead so the real
        # failure is reported and the exit below is reached.
        logger.Logger(SYSLOG_IDENTIFIER).log_error("Failed to load chassis due to {}".format(repr(e)))
        sys.exit(CHASSIS_LOAD_ERROR)

def get_formatted_time(datetimeobj=None, op_format=None):
    """
    Render a timestamp as text.
    :param datetimeobj: Optional - datetime to render; the current UTC time is used when omitted
    :param op_format: Optional - strftime format; the default yields e.g. "Tue Jan 02 01:04:05 PM UTC 2024"
    :returns: The formatted time string
    """
    moment = datetimeobj or datetime.now(timezone.utc)
    fmt = op_format or "%a %b %d %I:%M:%S %p UTC %Y"
    return moment.strftime(fmt)

#
# Module Config Updater ========================================================
#


class ModuleConfigUpdater(logger.Logger):
    """
    Applies configured admin states to supervisor, line-card and fabric-card
    modules through the platform chassis API.
    """

    def __init__(self, log_identifier, chassis):
        """
        Constructor for ModuleConfigUpdater
        :param log_identifier: Syslog identifier passed to the Logger base class
        :param chassis: Object representing a platform chassis
        """
        super(ModuleConfigUpdater, self).__init__(log_identifier)

        self.chassis = chassis

    def deinit(self):
        """
        Destructor of ModuleConfigUpdater. There is nothing to release.
        :return:
        """

    def module_config_update(self, key, admin_state):
        """
        Set the admin state of the module named ``key``.
        :param key: Module name; must start with the supervisor, line-card or fabric-card prefix
        :param admin_state: MODULE_ADMIN_DOWN or MODULE_ADMIN_UP
        :return: None; problems are logged
        """
        valid_prefixes = (ModuleBase.MODULE_TYPE_SUPERVISOR,
                          ModuleBase.MODULE_TYPE_LINE,
                          ModuleBase.MODULE_TYPE_FABRIC)
        if not key.startswith(valid_prefixes):
            self.log_error("Incorrect module-name {}. Should start with {} or {} or {}".format(key, *valid_prefixes))
            return

        module_index = try_get(self.chassis.get_module_index, key, default=INVALID_MODULE_INDEX)

        # Bail out if the platform cannot resolve the module name
        if module_index < 0:
            self.log_error("Unable to get module-index for key {} to set admin-state {}".format(key, admin_state))
            return

        # Report unexpected values instead of silently ignoring them
        if admin_state not in (MODULE_ADMIN_DOWN, MODULE_ADMIN_UP):
            self.log_warning("Invalid admin_state value: {}".format(admin_state))
            return

        # Setting the module to administratively up/down state
        self.log_info("Changing module {} to admin {} state".format(key, 'DOWN' if admin_state == MODULE_ADMIN_DOWN else 'UP'))
        try_get(self.chassis.get_module(module_index).set_admin_state, admin_state, default=False)

#
# SmartSwitch Module Config Updater ========================================================
#


class SmartSwitchModuleConfigUpdater(logger.Logger):
    """
    Applies configured admin states to SmartSwitch DPU modules. Each
    set_admin_state request is issued from its own worker thread.
    """

    def __init__(self, log_identifier, chassis):
        """
        Constructor for SmartSwitchModuleConfigUpdater
        :param log_identifier: Syslog identifier passed to the Logger base class
        :param chassis: Object representing a platform chassis
        """
        super(SmartSwitchModuleConfigUpdater, self).__init__(log_identifier)

        self.chassis = chassis

    def deinit(self):
        """
        Destructor of SmartSwitchModuleConfigUpdater. There is nothing to release.
        :return:
        """

    def module_config_update(self, key, admin_state):
        """
        Set the admin state of the DPU module named ``key`` asynchronously.
        :param key: Module name; must start with the DPU prefix
        :param admin_state: MODULE_ADMIN_DOWN or MODULE_ADMIN_UP
        :return: None; problems are logged
        """
        if not key.startswith(ModuleBase.MODULE_TYPE_DPU):
            self.log_error("Incorrect module-name {}. Should start with {}".format(key,
                                                        ModuleBase.MODULE_TYPE_DPU))
            return

        module_index = try_get(self.chassis.get_module_index, key, default=INVALID_MODULE_INDEX)

        # Bail out if the platform cannot resolve the module name
        if module_index < 0:
            self.log_error("Unable to get module-index for key {} to set admin-state {}".format(key, admin_state))
            return

        if admin_state in (MODULE_ADMIN_DOWN, MODULE_ADMIN_UP):
            self.log_info("Changing module {} to admin {} state".format(key, 'DOWN' if admin_state == MODULE_ADMIN_DOWN else 'UP'))
            # Issued off the caller's thread so the caller is not blocked by the
            # platform call (NOTE(review): assumed rationale — confirm).
            t = threading.Thread(target=self.submit_callback, args=(module_index, admin_state))
            t.start()
        else:
            self.log_warning("Invalid admin_state value: {}".format(admin_state))

    def submit_callback(self, module_index, admin_state):
        """
        Worker-thread target: apply ``admin_state`` to the module at ``module_index``.
        :param module_index: Index of the module in the chassis
        :param admin_state: MODULE_ADMIN_DOWN or MODULE_ADMIN_UP
        """
        try:
            try_get(self.chassis.get_module(module_index).set_admin_state, admin_state, default=False)
        except Exception as e:
            # This is the thread's top-level boundary. An uncaught exception
            # here would only be printed to stderr by threading.excepthook, so
            # it is sent to syslog instead.
            self.log_error("Failed to set admin state {} for module index {}: {}".format(
                admin_state, module_index, repr(e)))

#
# Module Updater ==============================================================
#


class ModuleUpdater(logger.Logger):
    """
    Publishes chassis, module, ASIC and midplane information for supervisor,
    line-card and fabric-card modules, tracks module down time, and on the
    supervisor cleans up CHASSIS_APP_DB entries of line cards that stay down
    for CHASSIS_DB_CLEANUP_MODULE_DOWN_PERIOD minutes.
    """

    def __init__(self, log_identifier, chassis, my_slot, supervisor_slot):
        """
        Constructor for ModuleUpdater
        :param log_identifier: Syslog identifier passed to the Logger base class
        :param chassis: Object representing a platform chassis
        :param my_slot: Slot of the card this daemon runs on
        :param supervisor_slot: Slot of the supervisor card
        """
        super(ModuleUpdater, self).__init__(log_identifier)

        self.chassis = chassis
        self.my_slot = my_slot
        self.supervisor_slot = supervisor_slot
        self.num_modules = chassis.get_num_modules()
        # Connect to STATE_DB and create chassis info tables
        state_db = daemon_base.db_connect("STATE_DB")
        self.chassis_table = swsscommon.Table(state_db, CHASSIS_INFO_TABLE)
        self.module_table = swsscommon.Table(state_db, CHASSIS_MODULE_INFO_TABLE)
        self.midplane_table = swsscommon.Table(state_db, CHASSIS_MIDPLANE_INFO_TABLE)
        self.info_dict_keys = [CHASSIS_MODULE_INFO_NAME_FIELD,
                               CHASSIS_MODULE_INFO_DESC_FIELD,
                               CHASSIS_MODULE_INFO_SLOT_FIELD,
                               CHASSIS_MODULE_INFO_SERIAL_FIELD,
                               CHASSIS_MODULE_INFO_OPERSTATUS_FIELD]

        self.chassis_state_db = daemon_base.db_connect("CHASSIS_STATE_DB")
        # The supervisor writes ASIC entries to the fabric ASIC table; other cards to the chassis ASIC table
        if self._is_supervisor():
            self.asic_table = swsscommon.Table(self.chassis_state_db,
                                               CHASSIS_FABRIC_ASIC_INFO_TABLE)
        else:
            self.asic_table = swsscommon.Table(self.chassis_state_db,
                                               CHASSIS_ASIC_INFO_TABLE)

        self.hostname_table = swsscommon.Table(self.chassis_state_db, CHASSIS_MODULE_HOSTNAME_TABLE)
        self.module_reboot_table = swsscommon.Table(self.chassis_state_db, CHASSIS_MODULE_REBOOT_INFO_TABLE)
        # Keyed by "<module name>|<hostname>"; values hold down_time, cleaned and slot
        self.down_modules = {}
        self.chassis_app_db_clean_sha = None

        # Seconds; optionally overridden by "linecard_reboot_timeout=<int>" in platform_env.conf
        self.linecard_reboot_timeout = DEFAULT_LINECARD_REBOOT_TIMEOUT
        if os.path.isfile(PLATFORM_ENV_CONF_FILE):
            with open(PLATFORM_ENV_CONF_FILE, 'r') as conf_file:
                for line in conf_file:
                    parts = line.split('=')
                    if parts[0].strip() != "linecard_reboot_timeout":
                        continue
                    # A malformed entry must not stop the daemon from starting;
                    # keep the previous value instead.
                    try:
                        self.linecard_reboot_timeout = int(parts[1].strip())
                    except (IndexError, ValueError):
                        self.log_error("Invalid linecard_reboot_timeout entry '{}' in {}; using {} seconds".format(
                            line.strip(), PLATFORM_ENV_CONF_FILE, self.linecard_reboot_timeout))

        self.midplane_initialized = try_get(chassis.init_midplane_switch, default=False)
        if not self.midplane_initialized:
            self.log_error("Chassisd midplane intialization failed")

    def deinit(self):
        """
        Destructor of ModuleUpdater: remove the entries this daemon published.
        :return:
        """
        # Delete all the information from DB and then exit
        for module_index in range(0, self.num_modules):
            name = try_get(self.chassis.get_module(module_index).get_name)
            self.module_table._del(name)
            if self.midplane_table.get(name) is not None:
                self.midplane_table._del(name)

        if self.chassis_table is not None:
            self.chassis_table._del(CHASSIS_INFO_KEY_TEMPLATE.format(1))

        # Only non-supervisor cards remove their ASIC entries
        if self.asic_table is not None and not self._is_supervisor():
            for asic in list(self.asic_table.getKeys()):
                self.asic_table._del(asic)

    def modules_num_update(self):
        """
        Publish the number of modules reported by the chassis to STATE_DB.
        """
        # Check if module list is populated
        num_modules = self.chassis.get_num_modules()
        if num_modules == 0:
            self.log_error("Chassisd has no modules available")
            return

        # Post number-of-modules info to STATE_DB
        fvs = swsscommon.FieldValuePairs([(CHASSIS_INFO_CARD_NUM_FIELD, str(num_modules))])
        self.chassis_table.set(CHASSIS_INFO_KEY_TEMPLATE.format(1), fvs)

    def get_module_current_status(self, key):
        """
        Return the oper_status last published for module ``key``, or
        ModuleBase.MODULE_STATUS_EMPTY when there is no entry.
        """
        fvs = self.module_table.get(key)
        if isinstance(fvs, list) and fvs[0] is True:
            fvs = dict(fvs[-1])
            return fvs[CHASSIS_MODULE_INFO_OPERSTATUS_FIELD]
        return ModuleBase.MODULE_STATUS_EMPTY

    def get_module_admin_status(self, chassis_module_name):
        """
        Return the configured admin_status of a module from CONFIG_DB.

        Defaults to 'up' when the module has no CONFIG_DB entry or the entry
        has no admin_status field.
        """
        config_db = daemon_base.db_connect("CONFIG_DB")
        vtable = swsscommon.Table(config_db, CHASSIS_CFG_TABLE)
        fvs = vtable.get(chassis_module_name)
        if isinstance(fvs, list) and fvs[0] is True:
            return dict(fvs[-1]).get(CHASSIS_MODULE_ADMIN_STATUS, 'up')
        return 'up'

    def module_db_update(self):
        """
        Refresh module and ASIC entries from the platform, track modules that
        leave the Online state, publish this line card's hostname, and drop
        ASIC entries belonging to modules that just went down.
        """
        notOnlineModules = []
        my_index = None
        valid_prefixes = (ModuleBase.MODULE_TYPE_SUPERVISOR,
                          ModuleBase.MODULE_TYPE_LINE,
                          ModuleBase.MODULE_TYPE_FABRIC)

        for module_index in range(0, self.num_modules):
            module_info_dict = self._get_module_info(module_index)
            if module_info_dict is None:
                continue

            if self.my_slot == module_info_dict[CHASSIS_MODULE_INFO_SLOT_FIELD]:
                my_index = module_index

            key = module_info_dict[CHASSIS_MODULE_INFO_NAME_FIELD]

            if not key.startswith(valid_prefixes):
                self.log_error("Incorrect module-name {}. Should start with {} or {} or {}".format(key, *valid_prefixes))
                continue

            slot = module_info_dict[CHASSIS_MODULE_INFO_SLOT_FIELD]
            fvs = swsscommon.FieldValuePairs([(CHASSIS_MODULE_INFO_DESC_FIELD, module_info_dict[CHASSIS_MODULE_INFO_DESC_FIELD]),
                                              (CHASSIS_MODULE_INFO_SLOT_FIELD, str(slot)),
                                              (CHASSIS_MODULE_INFO_OPERSTATUS_FIELD, module_info_dict[CHASSIS_MODULE_INFO_OPERSTATUS_FIELD]),
                                              (CHASSIS_MODULE_INFO_NUM_ASICS_FIELD, str(len(module_info_dict[CHASSIS_MODULE_INFO_ASICS]))),
                                              (CHASSIS_MODULE_INFO_SERIAL_FIELD, module_info_dict[CHASSIS_MODULE_INFO_SERIAL_FIELD])])
            prev_status = self.get_module_current_status(key)
            self.module_table.set(key, fvs)

            # Construct key for down_modules dict. Example down_modules key format: LINE-CARD0|<hostname>
            fvs = self.hostname_table.get(key)
            if isinstance(fvs, list) and fvs[0] is True:
                fvs = dict(fvs[-1])
                hostname = fvs[CHASSIS_MODULE_INFO_HOSTNAME_FIELD]
                down_module_key = key + '|' + hostname
            else:
                down_module_key = key + '|'

            if module_info_dict[CHASSIS_MODULE_INFO_OPERSTATUS_FIELD] != str(ModuleBase.MODULE_STATUS_ONLINE):
                if prev_status == ModuleBase.MODULE_STATUS_ONLINE:
                    notOnlineModules.append(key)
                    # Record when the module was detected down, so that the
                    # chassis db entries of its asics can be cleaned up if it
                    # stays down for CHASSIS_DB_CLEANUP_MODULE_DOWN_PERIOD minutes.
                    # All down modules, including the supervisor, are tracked to
                    # detect oper status changes, but cleanup is only attempted
                    # for line cards.
                    if down_module_key not in self.down_modules:
                        self.log_warning("Module {} (Slot {}) went off-line!".format(key, slot))
                        self.down_modules[down_module_key] = {
                            'down_time': time.time(),
                            'cleaned': False,
                            'slot': slot,
                        }
                continue

            # Module is operational. Remove it from down time tracking.
            if down_module_key in self.down_modules:
                self.log_notice("Module {} (Slot {}) recovered on-line!".format(key, slot))
                del self.down_modules[down_module_key]
            elif prev_status != ModuleBase.MODULE_STATUS_ONLINE:
                self.log_notice("Module {} (Slot {}) is on-line!".format(key, slot))

            # Only populate the ASIC table when the module is not configured down
            if self.get_module_admin_status(key) != 'down':
                for asic_id, asic in enumerate(module_info_dict[CHASSIS_MODULE_INFO_ASICS]):
                    asic_global_id, asic_pci_addr = asic
                    asic_key = "%s%s" % (CHASSIS_ASIC, asic_global_id)
                    if not self._is_supervisor():
                        asic_key = "%s|%s" % (key, asic_key)

                    asic_fvs = swsscommon.FieldValuePairs([(CHASSIS_ASIC_PCI_ADDRESS_FIELD, asic_pci_addr),
                                                           (CHASSIS_MODULE_INFO_NAME_FIELD, key),
                                                           (CHASSIS_ASIC_ID_IN_MODULE_FIELD, str(asic_id))])
                    self.asic_table.set(asic_key, asic_fvs)

        # On a line card, publish its hostname and num_asics to the chassis state db.
        # The hostname is used as the key to access chassis app db entries.
        if not self._is_supervisor():
            if my_index is None:
                # _get_module_info(None) would fail inside chassis.get_module()
                self.log_error("No module found for slot {}; hostname not published to chassis state db".format(self.my_slot))
            else:
                module_info_dict = self._get_module_info(my_index)
                hostname_key = "{}{}".format(ModuleBase.MODULE_TYPE_LINE, int(self.my_slot) - 1)
                hostname = try_get(device_info.get_hostname, default="None")
                hostname_fvs = swsscommon.FieldValuePairs([(CHASSIS_MODULE_INFO_SLOT_FIELD, str(self.my_slot)),
                                                           (CHASSIS_MODULE_INFO_HOSTNAME_FIELD, hostname),
                                                           (CHASSIS_MODULE_INFO_NUM_ASICS_FIELD, str(len(module_info_dict[CHASSIS_MODULE_INFO_ASICS])))])
                self.hostname_table.set(hostname_key, hostname_fvs)

        # Asics that are on the "not online" modules need to be cleaned up
        if notOnlineModules:
            for asic in list(self.asic_table.getKeys()):
                fvs = self.asic_table.get(asic)
                if isinstance(fvs, list):
                    fvs = dict(fvs[-1])
                if CHASSIS_MODULE_INFO_NAME_FIELD in fvs.keys() and fvs[CHASSIS_MODULE_INFO_NAME_FIELD] in notOnlineModules:
                    self.asic_table._del(asic)

    def _get_module_info(self, module_index):
        """
        Collect name, description, slot, oper status, asics and serial of the
        module at ``module_index``, using platform defaults when unimplemented.
        """
        module = self.chassis.get_module(module_index)
        module_info_dict = dict.fromkeys(self.info_dict_keys, 'N/A')
        name = try_get(module.get_name)
        desc = try_get(module.get_description)
        slot = try_get(module.get_slot, default=INVALID_SLOT)
        status = try_get(module.get_oper_status,
                         default=ModuleBase.MODULE_STATUS_OFFLINE)
        asics = try_get(module.get_all_asics,
                        default=[])
        serial = try_get(module.get_serial)

        module_info_dict[CHASSIS_MODULE_INFO_NAME_FIELD] = name
        module_info_dict[CHASSIS_MODULE_INFO_DESC_FIELD] = str(desc)
        module_info_dict[CHASSIS_MODULE_INFO_SLOT_FIELD] = slot
        module_info_dict[CHASSIS_MODULE_INFO_OPERSTATUS_FIELD] = str(status)
        module_info_dict[CHASSIS_MODULE_INFO_ASICS] = asics
        module_info_dict[CHASSIS_MODULE_INFO_SERIAL_FIELD] = str(serial)

        return module_info_dict

    def _is_supervisor(self):
        """Return True when this daemon runs on the supervisor slot."""
        return self.my_slot == self.supervisor_slot

    def is_module_reboot_expected(self, key):
        """Return True when the reboot-info entry of ``key`` marks the reboot as expected."""
        fvs = self.module_reboot_table.get(key)
        if isinstance(fvs, list) and fvs[0] is True:
            fvs = dict(fvs[-1])
            # The entry may carry only a timestamp; treat a missing field as "not expected"
            if fvs.get(CHASSIS_MODULE_REBOOT_REBOOT_FIELD) == "expected":
                return True
        return False

    def module_reboot_set_time(self, key):
        """Record the current time as the reboot timestamp of module ``key``."""
        time_now = time.time()
        fvs = swsscommon.FieldValuePairs([(CHASSIS_MODULE_REBOOT_TIMESTAMP_FIELD, str(time_now))])
        self.module_reboot_table.set(key, fvs)

    def is_module_reboot_system_up_expired(self, key):
        """
        Return True (and delete the reboot-info entry) once linecard_reboot_timeout
        seconds have passed since the recorded reboot timestamp of ``key``.
        """
        fvs = self.module_reboot_table.get(key)
        if isinstance(fvs, list) and fvs[0] is True:
            fvs = dict(fvs[-1])
            if CHASSIS_MODULE_REBOOT_TIMESTAMP_FIELD in fvs.keys():
                timestamp = float(fvs[CHASSIS_MODULE_REBOOT_TIMESTAMP_FIELD])
                time_now = time.time()
                if time_now - timestamp >= self.linecard_reboot_timeout:
                    self.module_reboot_table._del(key)
                    return True
        return False

    def check_midplane_reachability(self):
        """
        Update midplane IP/reachability entries and log connectivity changes.

        On the supervisor every non-fabric module except the supervisor is
        checked; on a line card only the supervisor is checked.
        """
        if not self.midplane_initialized:
            return

        for index, module in enumerate(self.chassis.get_all_modules()):
            # Skip fabric cards
            if module.get_type() == ModuleBase.MODULE_TYPE_FABRIC:
                continue

            if self._is_supervisor():
                # On supervisor skip checking for supervisor
                if module.get_slot() == self.supervisor_slot:
                    continue
            else:
                # On line-card check only supervisor
                if module.get_slot() != self.supervisor_slot:
                    continue

            module_key = try_get(module.get_name, default='MODULE {}'.format(index))
            midplane_ip = try_get(module.get_midplane_ip, default=INVALID_IP)
            midplane_access = try_get(module.is_midplane_reachable, default=False)

            # Generate syslog for the loss of midplane connectivity when midplane connectivity
            # loss is detected for the first time
            current_midplane_state = 'False'
            fvs = self.midplane_table.get(module_key)
            if isinstance(fvs, list) and fvs[0] is True:
                fvs = dict(fvs[-1])
                current_midplane_state = fvs[CHASSIS_MIDPLANE_INFO_ACCESS_FIELD]

            if midplane_access is False and current_midplane_state == 'True':
                if self.is_module_reboot_expected(module_key):
                    self.module_reboot_set_time(module_key)
                    self.log_warning("Expected: Module {} (Slot {}) lost midplane connectivity".format(module_key, module.get_slot()))
                else:
                    self.log_warning("Unexpected: Module {} (Slot {}) lost midplane connectivity".format(module_key, module.get_slot()))
            elif midplane_access is True and current_midplane_state == 'False':
                self.log_notice("Module {} (Slot {}) midplane connectivity is up".format(module_key, module.get_slot()))
                # clean up the reboot_info_table
                if self.module_reboot_table.get(module_key) is not None:
                    self.module_reboot_table._del(module_key)
            elif midplane_access is False and current_midplane_state == 'False':
                if self.is_module_reboot_system_up_expired(module_key):
                    self.log_warning("Unexpected: Module {} (Slot {}) midplane connectivity is not restored in {} seconds".format(module_key, module.get_slot(), self.linecard_reboot_timeout))

            # Update db with midplane information
            fvs = swsscommon.FieldValuePairs([(CHASSIS_MIDPLANE_INFO_IP_FIELD, midplane_ip),
                                              (CHASSIS_MIDPLANE_INFO_ACCESS_FIELD, str(midplane_access))])
            self.midplane_table.set(module_key, fvs)

    def _cleanup_chassis_app_db(self, module_host):
        """
        Delete the CHASSIS_APP_DB entries created by every asic of a module.
        :param module_host: down_modules key of the form "<module name>|<hostname>"
        """
        if self.chassis_app_db_clean_sha is None:
            self.chassis_app_db = daemon_base.db_connect("CHASSIS_APP_DB")
            self.chassis_app_db_pipe = swsscommon.RedisPipeline(self.chassis_app_db)

            # Lua script for chassis db cleanup for a specific asic
            # The clean up operation is required to delete only those entries created by
            # the asic that lost connection. Entries from the following tables are deleted
            #   (1) SYSTEM_NEIGH
            #   (2) SYSTEM_INTERFACE
            #   (3) SYSTEM_LAG_MEMBER_TABLE
            #   (4) SYSTEM_LAG_TABLE
            #   (5) The corresponding LAG IDs of the entries from SYSTEM_LAG_TABLE
            #       SYSTEM_LAG_ID_TABLE and SYSTEM_LAG_ID_SET are adjusted appropriately

            script = "local host = string.gsub(ARGV[1], '%-', '%%-')\n\
        local dev = ARGV[2]\n\
        local tables = {'SYSTEM_NEIGH*', 'SYSTEM_INTERFACE*', 'SYSTEM_LAG_MEMBER_TABLE*'}\n\
        for i = 1, table.getn(tables) do\n\
            local ps = tables[i] .. '|' .. host .. '|' .. dev\n\
            local keylist = redis.call('KEYS', tables[i])\n\
            for j,key in ipairs(keylist) do\n\
                if string.match(key, ps) ~= nil then\n\
                    redis.call('DEL', key)\n\
                end\n\
            end\n\
        end\n\
        local ps = 'SYSTEM_LAG_TABLE*|' .. '(' .. host .. '|' .. dev ..'.*' .. ')'\n\
        local keylist = redis.call('KEYS', 'SYSTEM_LAG_TABLE*')\n\
        for j,key in ipairs(keylist) do\n\
            local lagname = string.match(key, ps)\n\
            if lagname ~= nil then\n\
                redis.call('DEL', key)\n\
                local lagid = redis.call('HGET', 'SYSTEM_LAG_ID_TABLE', lagname)\n\
                redis.call('SREM', 'SYSTEM_LAG_ID_SET', lagid)\n\
                redis.call('HDEL', 'SYSTEM_LAG_ID_TABLE', lagname)\n\
                redis.call('rpush', 'SYSTEM_LAG_IDS_FREE_LIST', lagid)\n\
            end\n\
        end\n\
        return"
            self.chassis_app_db_clean_sha = self.chassis_app_db_pipe.loadRedisScript(script)

        # Chassis app db cleanup of all asics of the module

        # Get the module key and host name from down_modules key. Split only on
        # the first '|' (the key is built as "<module>|<hostname>").
        module, lc = module_host.split('|', 1)

        if lc == '':
            # Host name is not available for this module. No clean up is needed
            self.log_notice("Host name is not available for Module {}. Chassis db clean up not done!".format(module))
            return

        # Get number of asics in the module
        fvs = self.hostname_table.get(module)
        if isinstance(fvs, list) and fvs[0] is True:
            fvs = dict(fvs[-1])
            num_asics = int(fvs[CHASSIS_MODULE_INFO_NUM_ASICS_FIELD])
        else:
            num_asics = 0

        for asic_id in range(0, num_asics):
            asic = CHASSIS_ASIC + str(asic_id)

            # Cleanup the chassis app db entries using lua script
            redis_cmd = ['redis-cli', '-h', 'redis_chassis.server', '-p', '6380', '-n', '12', 'EVALSHA', self.chassis_app_db_clean_sha, '0', lc, asic]
            try:
                subp = subprocess.Popen(redis_cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
                outs = subp.communicate()
            except Exception:
                self.log_error("Failed to clean up chassis app db entries for {}({})/{}".format(module, lc, asic))
                continue

            # Popen only raises when redis-cli cannot be started; a non-zero
            # exit status means the command itself failed.
            if subp.returncode != 0:
                errs = outs[1] if outs and len(outs) > 1 and outs[1] else ''
                self.log_error("Failed to clean up chassis app db entries for {}({})/{}: {}".format(module, lc, asic, errs.strip()))
            else:
                self.log_notice("Cleaned up chassis app db entries for {}({})/{}".format(module, lc, asic))

    def module_down_chassis_db_cleanup(self):
        """
        On the supervisor, clean up the chassis app db entries of line cards that
        have been down for at least CHASSIS_DB_CLEANUP_MODULE_DOWN_PERIOD minutes.
        Each down module is processed once.
        """
        if not self._is_supervisor():
            return
        time_now = time.time()
        for module, info in self.down_modules.items():
            if info['cleaned']:
                continue
            slot = info['slot']
            delta = (time_now - info['down_time']) / 60
            if delta < CHASSIS_DB_CLEANUP_MODULE_DOWN_PERIOD:
                continue
            if module.startswith(ModuleBase.MODULE_TYPE_LINE):
                # Module is down for more than CHASSIS_DB_CLEANUP_MODULE_DOWN_PERIOD minutes. Do the chassis clean up
                self.log_notice("Module {} (Slot {}) is down for long time. Initiating chassis app db clean up".format(module, slot))
                self._cleanup_chassis_app_db(module)
            info['cleaned'] = True

#
# SmartSwitch Module Updater ==================================================
#


class SmartSwitchModuleUpdater(ModuleUpdater):
    """
    Module updater for SmartSwitch platforms, whose modules are DPUs.

    Publishes DPU module information and midplane reachability to STATE_DB,
    records DPU offline/online transitions as reboot-cause history under
    MODULE_REBOOT_CAUSE_DIR, and mirrors that history into CHASSIS_STATE_DB.
    """

    def __init__(self, log_identifier, chassis):
        """
        Constructor for SmartSwitchModuleUpdater
        :param log_identifier: Syslog identifier used for logging
        :param chassis: Object representing a platform chassis
        """
        # Bypasses ModuleUpdater.__init__ (which also takes my_slot and
        # supervisor_slot) and initializes the next class in the MRO directly.
        super(ModuleUpdater, self).__init__(log_identifier)

        self.chassis = chassis
        self.num_modules = self.chassis.get_num_modules()
        # Connect to STATE_DB and create chassis info tables
        state_db = daemon_base.db_connect("STATE_DB")
        self.chassis_table = swsscommon.Table(state_db, CHASSIS_INFO_TABLE)
        self.module_table = swsscommon.Table(state_db, CHASSIS_MODULE_INFO_TABLE)
        self.midplane_table = swsscommon.Table(state_db, CHASSIS_MIDPLANE_INFO_TABLE)
        self.info_dict_keys = [CHASSIS_MODULE_INFO_NAME_FIELD,
                               CHASSIS_MODULE_INFO_DESC_FIELD,
                               CHASSIS_MODULE_INFO_SLOT_FIELD,
                               CHASSIS_MODULE_INFO_SERIAL_FIELD,
                               CHASSIS_MODULE_INFO_OPERSTATUS_FIELD]

        self.chassis_state_db = daemon_base.db_connect("CHASSIS_STATE_DB")

        self.hostname_table = swsscommon.Table(self.chassis_state_db, CHASSIS_MODULE_HOSTNAME_TABLE)
        self.module_reboot_table = swsscommon.Table(self.chassis_state_db, CHASSIS_MODULE_REBOOT_INFO_TABLE)
        self.down_modules = {}
        self.chassis_app_db_clean_sha = None

        self.midplane_initialized = try_get(chassis.init_midplane_switch, default=False)
        if not self.midplane_initialized:
            self.log_error("Chassisd midplane intialization failed")

        # Optional override of the DPU reboot timeout from platform.json
        self.dpu_reboot_timeout = DEFAULT_DPU_REBOOT_TIMEOUT
        if os.path.isfile(PLATFORM_JSON_FILE):
            try:
                with open(PLATFORM_JSON_FILE, 'r') as f:
                    platform_cfg = json.load(f)
                self.dpu_reboot_timeout = int(platform_cfg.get("dpu_reboot_timeout", DEFAULT_DPU_REBOOT_TIMEOUT))
            except (json.JSONDecodeError, ValueError) as e:
                self.log_error("Error parsing {}: {}".format(PLATFORM_JSON_FILE, e))
            except Exception as e:
                self.log_error("Unexpected error: {}".format(e))

    def deinit(self):
        """
        Destructor of SmartSwitchModuleUpdater: remove this daemon's module,
        midplane and chassis entries from STATE_DB.
        """
        for module_index in range(0, self.num_modules):
            name = try_get(self.chassis.get_module(module_index).get_name)
            self.module_table._del(name)
            if self.midplane_table.get(name) is not None:
                self.midplane_table._del(name)

        if self.chassis_table is not None:
            self.chassis_table._del(CHASSIS_INFO_KEY_TEMPLATE.format(1))

    def get_module_admin_status(self, chassis_module_name):
        """
        Return the configured admin status of a module from CONFIG_DB's
        CHASSIS_MODULE table, or 'empty' when the module has no entry.
        """
        config_db = daemon_base.db_connect("CONFIG_DB")
        vtable = swsscommon.Table(config_db, CHASSIS_CFG_TABLE)
        fvs = vtable.get(chassis_module_name)
        if isinstance(fvs, list) and fvs[0] is True:
            fvs = dict(fvs[-1])
            return fvs[CHASSIS_MODULE_ADMIN_STATUS]
        return 'empty'

    def module_db_update(self):
        """
        Publish every DPU's info to CHASSIS_MODULE_TABLE and react to
        operational-status transitions:
          - into offline: persist the time the DPU went down;
          - out of offline: if a down time was persisted or the module's
            reboot-cause.txt reads "First boot", persist the platform reboot
            cause and publish the reboot-cause history to CHASSIS_STATE_DB.
        Modules whose name does not start with MODULE_TYPE_DPU are skipped.
        """
        offline = str(ModuleBase.MODULE_STATUS_OFFLINE)

        for module_index in range(0, self.num_modules):
            module_info_dict = self._get_module_info(module_index)
            if module_info_dict is None:
                continue

            key = module_info_dict[CHASSIS_MODULE_INFO_NAME_FIELD]
            if not key.startswith(ModuleBase.MODULE_TYPE_DPU):
                self.log_error("Incorrect module-name {}. Should start with {} ".format(key,
                                                        ModuleBase.MODULE_TYPE_DPU))
                continue

            fvs = swsscommon.FieldValuePairs([(CHASSIS_MODULE_INFO_DESC_FIELD, module_info_dict[CHASSIS_MODULE_INFO_DESC_FIELD]),
                                              (CHASSIS_MODULE_INFO_SLOT_FIELD, module_info_dict[CHASSIS_MODULE_INFO_SLOT_FIELD]),
                                              (CHASSIS_MODULE_INFO_OPERSTATUS_FIELD, module_info_dict[CHASSIS_MODULE_INFO_OPERSTATUS_FIELD]),
                                              (CHASSIS_MODULE_INFO_SERIAL_FIELD, module_info_dict[CHASSIS_MODULE_INFO_SERIAL_FIELD])])

            # Snapshot the previous operational status before overwriting it
            prev_status = self.get_module_current_status(key)
            self.module_table.set(key, fvs)

            current_status = module_info_dict[CHASSIS_MODULE_INFO_OPERSTATUS_FIELD]

            if prev_status != offline and current_status == offline:
                self.log_notice("{} operational status transitioning to offline".format(key))
                # Persist dpu down time
                self.persist_dpu_reboot_time(key)

            elif prev_status == offline and current_status != offline:
                self.log_notice("{} operational status transitioning to online".format(key))
                reboot_cause = try_get(self.chassis.get_module(module_index).get_reboot_cause)

                if self.retrieve_dpu_reboot_time(key) is not None or self._is_first_boot(key):
                    self.persist_dpu_reboot_cause(reboot_cause, key)
                    self.update_dpu_reboot_cause_to_db(key)

    def _get_module_info(self, module_index):
        """
        Retrieves module info of the DPU at module_index.

        :return: dict keyed by self.info_dict_keys. The name is stored as
                 returned by try_get; description, oper status and serial are
                 stringified; the slot is always 'N/A'.
        """
        module = self.chassis.get_module(module_index)

        module_info_dict = dict.fromkeys(self.info_dict_keys, 'N/A')
        module_info_dict[CHASSIS_MODULE_INFO_NAME_FIELD] = try_get(module.get_name)
        module_info_dict[CHASSIS_MODULE_INFO_DESC_FIELD] = str(try_get(module.get_description))
        module_info_dict[CHASSIS_MODULE_INFO_SLOT_FIELD] = 'N/A'
        module_info_dict[CHASSIS_MODULE_INFO_OPERSTATUS_FIELD] = str(
            try_get(module.get_oper_status, default=ModuleBase.MODULE_STATUS_OFFLINE))
        module_info_dict[CHASSIS_MODULE_INFO_SERIAL_FIELD] = str(try_get(module.get_serial))

        return module_info_dict

    def update_dpu_state(self, key, state):
        """
        Set the midplane link state fields of `key` in CHASSIS_STATE_DB: the
        given state, an empty reason and the current formatted time.
        Errors are logged, not raised.
        """
        try:
            if not self.chassis_state_db:
                self.chassis_state_db = daemon_base.db_connect("CHASSIS_STATE_DB")

            updates = {
                "dpu_midplane_link_state": state,
                "dpu_midplane_link_reason": "",
                "dpu_midplane_link_time": get_formatted_time(),
            }

            for field, value in updates.items():
                self.chassis_state_db.hset(key, field, value)

        except Exception as e:
            self.log_error(f"Unexpected error: {e}")

    def get_dpu_midplane_state(self, key):
        """
        Return the dpu_midplane_link_state field of `key` in CHASSIS_STATE_DB.
        Returns None (after logging) if the lookup raises.
        """
        try:
            if not self.chassis_state_db:
                self.chassis_state_db = daemon_base.db_connect("CHASSIS_STATE_DB")

            return self.chassis_state_db.hget(key, "dpu_midplane_link_state")

        except Exception as e:
            self.log_error(f"Unexpected error: {e}")
            return None

    def _convert_to_dict(self, data):
        """
        Converts SWIG proxy object or native dict to a Python dictionary.
        """
        if isinstance(data, dict):
            return data
        return dict(data)

    def _get_current_time_str(self):
        """Returns the current time as a string in 'YYYY_MM_DD_HH_MM_SS' format."""
        return get_formatted_time(op_format="%Y_%m_%d_%H_%M_%S")

    def _get_history_path(self, module, file_name):
        """Generates the full path for history files."""
        return os.path.join(MODULE_REBOOT_CAUSE_DIR, module.lower(), "history", file_name)

    def _is_first_boot(self, module):
        """Return True if the module's reboot-cause.txt reads exactly "First boot"."""
        file_path = os.path.join(MODULE_REBOOT_CAUSE_DIR, module.lower(), "reboot-cause.txt")

        try:
            with open(file_path, 'r') as f:
                return f.read().strip() == "First boot"
        except FileNotFoundError:
            return False

    def persist_dpu_reboot_time(self, module):
        """Persist the current time as the module's down time (prev_reboot_time.txt)."""
        time_str = self._get_current_time_str()
        path = self._get_history_path(module, "prev_reboot_time.txt")

        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'w') as f:
            f.write(time_str)

    def retrieve_dpu_reboot_time(self, module):
        """Return the persisted down time of the module, or None if none was recorded."""
        path = self._get_history_path(module, "prev_reboot_time.txt")

        try:
            with open(path, 'r') as f:
                return f.read().strip()
        except FileNotFoundError:
            return None

    def persist_dpu_reboot_cause(self, reboot_cause, module):
        """
        Persist the reboot cause of a module and rotate its history.

        Writes a JSON history entry "<down_time>_reboot_cause.txt" (the
        current time when no down time was recorded), consumes
        prev_reboot_time.txt, overwrites reboot-cause.txt with the raw reboot
        cause, repoints previous-reboot-cause.json at the new history entry
        and trims the history to MAX_HISTORY_FILES files.

        :param reboot_cause: Platform reboot cause: a "cause,comment" string
                             or a (cause, comment) pair; falsy means unknown.
        :param module: Module name (the CHASSIS_MODULE_TABLE key).
        """
        if reboot_cause:
            try:
                cause, comment = (
                    reboot_cause.split(",", 1) if isinstance(reboot_cause, str) else reboot_cause
                )
            except (ValueError, TypeError):
                # String without a comma, a sequence that is not a pair, or a
                # non-iterable value.
                cause = reboot_cause if isinstance(reboot_cause, str) else "Unknown"
                comment = "N/A"
        else:
            cause, comment = "Unknown", "N/A"

        prev_reboot_time = self.retrieve_dpu_reboot_time(module)
        if prev_reboot_time is None:
            prev_reboot_time = self._get_current_time_str()

        # The recorded down time is consumed by this history entry
        prev_reboot_path = self._get_history_path(module, "prev_reboot_time.txt")
        if os.path.exists(prev_reboot_path):
            os.remove(prev_reboot_path)

        file_name = f"{prev_reboot_time}_reboot_cause.txt"
        file_path = self._get_history_path(module, file_name)
        try:
            formatted_time = get_formatted_time(datetimeobj=datetime.strptime(prev_reboot_time, "%Y_%m_%d_%H_%M_%S"))
        except ValueError:
            formatted_time = get_formatted_time()

        reboot_cause_dict = {
            "cause": cause,
            "comment": comment,
            "device": module,
            "time": formatted_time,
            "name": prev_reboot_time,
        }

        # On the "First boot" path no down time was ever persisted, so the
        # history directory may not exist yet.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'w') as f:
            json.dump(reboot_cause_dict, f)

        # Write the reboot_cause content to the reboot-cause.txt file, overwriting it
        reboot_cause_path = os.path.join(MODULE_REBOOT_CAUSE_DIR, module.lower(), "reboot-cause.txt")
        os.makedirs(os.path.dirname(reboot_cause_path), exist_ok=True)
        with open(reboot_cause_path, 'w') as cause_file:
            cause_file.write(json.dumps(reboot_cause) + '\n')

        # Update symlink to the latest reboot cause file. lexists() also sees a
        # dangling link (exists() does not), which would otherwise make
        # os.symlink() fail with FileExistsError.
        symlink_path = os.path.join(MODULE_REBOOT_CAUSE_DIR, module.lower(), "previous-reboot-cause.json")
        if os.path.lexists(symlink_path):
            os.remove(symlink_path)
        if os.path.exists(file_path):
            os.symlink(file_path, symlink_path)

        self._rotate_files(module)

    def _rotate_files(self, module):
        """Delete the lexicographically oldest history files beyond MAX_HISTORY_FILES."""
        history_dir = os.path.join(MODULE_REBOOT_CAUSE_DIR, module.lower(), "history")
        os.makedirs(history_dir, exist_ok=True)
        try:
            files = sorted(os.listdir(history_dir))
        except FileNotFoundError:
            return

        # Names start with a YYYY_MM_DD_HH_MM_SS stamp, so name order is age order
        if len(files) > MAX_HISTORY_FILES:
            for old_file in files[:-MAX_HISTORY_FILES]:
                os.remove(os.path.join(history_dir, old_file))

    def update_dpu_reboot_cause_to_db(self, module):
        """
        Replace the module's REBOOT_CAUSE|<MODULE>|* entries in CHASSIS_STATE_DB
        with one entry per reboot-cause history file. Unreadable files are
        logged and skipped.
        """
        if not self.chassis_state_db:
            self.chassis_state_db = daemon_base.db_connect("CHASSIS_STATE_DB")

        # Delete existing keys for the module in CHASSIS_STATE_DB
        pattern = f"REBOOT_CAUSE|{module.upper()}|*"
        keys = self.chassis_state_db.keys(pattern)
        if keys:
            for key in keys:
                self.chassis_state_db.delete(key)

        # Read the history from the same location persist_dpu_reboot_cause writes it
        reboot_cause_files = glob.glob(self._get_history_path(module, "*_reboot_cause.txt"))

        if not reboot_cause_files:
            self.log_warning(f"No reboot cause history files found for module: {module}")
            return

        for file_path in reboot_cause_files:
            try:
                with open(file_path, "r") as file:
                    reboot_cause_dict = json.load(file)

                if not reboot_cause_dict:
                    self.log_warning(f"{module} reboot_cause_dict is empty")
                    continue

                reboot_time = reboot_cause_dict.get("name", self._get_current_time_str())
                key = f"REBOOT_CAUSE|{module.upper()}|{reboot_time}"

                for field, value in reboot_cause_dict.items():
                    if field and value is not None:
                        self.chassis_state_db.hset(key, field, value)

            except json.JSONDecodeError:
                self.log_warning(f"Failed to decode JSON from file: {file_path}")
            except Exception as e:
                self.log_warning(f"Error processing file {file_path}: {e}")

    def check_midplane_reachability(self):
        """
        Refresh the midplane reachability of every module.

        Logs connectivity transitions relative to the value stored in the
        STATE_DB midplane table, updates the CHASSIS_STATE_DB DPU_STATE entry
        when its link state differs, and stores midplane IP/access in STATE_DB.
        No-op when midplane initialization failed.
        """
        if not self.midplane_initialized:
            return

        for index, module in enumerate(self.chassis.get_all_modules()):
            module_key = try_get(module.get_name, default='MODULE {}'.format(index))
            midplane_ip = try_get(module.get_midplane_ip, default=INVALID_IP)
            midplane_access = try_get(module.is_midplane_reachable, default=False)

            # Compare against the last stored state so transitions are logged once
            current_midplane_state = 'False'
            fvs = self.midplane_table.get(module_key)
            if isinstance(fvs, list) and fvs[0] is True:
                fvs = dict(fvs[-1])
                current_midplane_state = fvs[CHASSIS_MIDPLANE_INFO_ACCESS_FIELD]

            if midplane_access is False and current_midplane_state == 'True':
                self.log_warning("Unexpected: Module {} lost midplane connectivity".format(module_key))
            elif midplane_access is True and current_midplane_state == 'False':
                self.log_notice("Module {} midplane connectivity is up".format(module_key))

            # Update midplane state in the chassisStateDB DPU_STATE table
            key = "DPU_STATE|" + module_key
            dpu_mp_state = self.get_dpu_midplane_state(key)
            if midplane_access and dpu_mp_state != 'up':
                self.update_dpu_state(key, 'up')
            elif not midplane_access and dpu_mp_state != 'down':
                self.update_dpu_state(key, "down")

            fvs = swsscommon.FieldValuePairs([(CHASSIS_MIDPLANE_INFO_IP_FIELD, midplane_ip),
                                              (CHASSIS_MIDPLANE_INFO_ACCESS_FIELD, str(midplane_access))])
            self.midplane_table.set(module_key, fvs)

    def module_down_chassis_db_cleanup(self):
        """
        For every module whose admin status is not 'up', delete all
        CHASSIS_STATE_DB keys containing the module name, except DPU_STATE and
        REBOOT_CAUSE keys.
        """
        if not self.chassis_state_db:
            self.chassis_state_db = daemon_base.db_connect("CHASSIS_STATE_DB")

        for module_index in range(0, self.num_modules):
            name = try_get(self.chassis.get_module(module_index).get_name)
            if self.get_module_admin_status(name) == 'up':
                continue
            keys = self.chassis_state_db.keys("*" + name + "*")
            if not keys:
                continue
            for key in keys:
                if "DPU_STATE" not in key and "REBOOT_CAUSE" not in key:
                    self.chassis_state_db.delete(key)


#
# Config Manager task ========================================================
#


class ConfigManagerTask(ProcessTaskBase):
    """
    Background task that subscribes to the CONFIG_DB CHASSIS_MODULE table and
    forwards each change to ModuleConfigUpdater.module_config_update():
    a SET maps to MODULE_ADMIN_DOWN, a DEL to MODULE_ADMIN_UP, and any other
    operation is ignored.
    """

    def __init__(self):
        ProcessTaskBase.__init__(self)

        # TODO: Refactor to eliminate the need for this Logger instance
        self.logger = logger.Logger(SYSLOG_IDENTIFIER)

    def task_worker(self):
        self.config_updater = ModuleConfigUpdater(SYSLOG_IDENTIFIER, get_chassis())
        config_db = daemon_base.db_connect("CONFIG_DB")

        selector = swsscommon.Select()
        subscriber = swsscommon.SubscriberStateTable(config_db, CHASSIS_CFG_TABLE)
        selector.addSelectable(subscriber)

        op_to_admin_state = {
            'SET': MODULE_ADMIN_DOWN,
            'DEL': MODULE_ADMIN_UP,
        }

        while True:
            # Bounded wait so the process still gets to handle signals
            # (e.g. SIGTERM for graceful shutdown)
            state, _ = selector.select(SELECT_TIMEOUT)

            if state == swsscommon.Select.TIMEOUT:
                # Timeouts are routine; stay quiet
                continue
            if state != swsscommon.Select.OBJECT:
                self.logger.log_warning("sel.select() did not return swsscommon.Select.OBJECT")
                continue

            key, op, _ = subscriber.pop()
            if op not in op_to_admin_state:
                continue

            self.config_updater.module_config_update(key, op_to_admin_state[op])


#
# SmartSwitch Config Manager task ========================================================
#


class SmartSwitchConfigManagerTask(ProcessTaskBase):
    """
    SmartSwitch variant of the config task: subscribes to the CONFIG_DB
    CHASSIS_MODULE table and forwards an admin state to
    SmartSwitchModuleConfigUpdater.module_config_update(). A SET whose
    admin_status field is 'up' maps to MODULE_ADMIN_UP; any other SET and
    every DEL map to MODULE_ADMIN_DOWN; other operations are ignored.
    """

    def __init__(self):
        ProcessTaskBase.__init__(self)

        # TODO: Refactor to eliminate the need for this Logger instance
        self.logger = logger.Logger(SYSLOG_IDENTIFIER)

    def task_worker(self):
        self.config_updater = SmartSwitchModuleConfigUpdater(SYSLOG_IDENTIFIER, get_chassis())
        config_db = daemon_base.db_connect("CONFIG_DB")

        selector = swsscommon.Select()
        subscriber = swsscommon.SubscriberStateTable(config_db, CHASSIS_CFG_TABLE)
        selector.addSelectable(subscriber)

        while True:
            # Bounded wait so the process still gets to handle signals
            # (e.g. SIGTERM for graceful shutdown)
            state, _ = selector.select(SELECT_TIMEOUT)

            if state == swsscommon.Select.TIMEOUT:
                # Timeouts are routine; stay quiet
                continue
            if state != swsscommon.Select.OBJECT:
                self.logger.log_warning("sel.select() did not return swsscommon.Select.OBJECT")
                continue

            key, op, fvp = subscriber.pop()

            if op == 'DEL':
                admin_state = MODULE_ADMIN_DOWN
            elif op == 'SET':
                requested = dict(fvp).get('admin_status')
                admin_state = MODULE_ADMIN_UP if requested == 'up' else MODULE_ADMIN_DOWN
            else:
                continue

            self.config_updater.module_config_update(key, admin_state)

#
# State Manager task ========================================================
#

class DpuStateUpdater(logger.Logger):
    """
    Publishes this DPU's data-plane and control-plane state ('up'/'down' plus
    an update timestamp) to the DPU_STATE table of CHASSIS_STATE_DB under the
    entry "DPU<id>".

    Each state comes from the platform chassis when its
    get_dataplane_state()/get_controlplane_state() does not raise
    NotImplementedError; otherwise it is derived from APPL_DB PORT_TABLE
    oper_status (data plane) or STATE_DB SYSTEM_READY (control plane).
    """

    DP_STATE = 'dpu_data_plane_state'
    DP_UPDATE_TIME = 'dpu_data_plane_time'
    CP_STATE = 'dpu_control_plane_state'
    CP_UPDATE_TIME = 'dpu_control_plane_time'

    def __init__(self, log_identifier, chassis):
        super(DpuStateUpdater, self).__init__(log_identifier)

        self.chassis = chassis

        self.state_db = daemon_base.db_connect('STATE_DB')
        self.app_db = daemon_base.db_connect('APPL_DB')
        self.chassis_state_db = daemon_base.db_connect('CHASSIS_STATE_DB')

        self.config_db = swsscommon.ConfigDBConnector()
        self.config_db.connect()

        # Probe each platform getter once and keep it, or the generic fallback
        self._get_dp_state = self._pick_state_source(self.chassis.get_dataplane_state,
                                                     self._get_data_plane_state_common)
        self._get_cp_state = self._pick_state_source(self.chassis.get_controlplane_state,
                                                     self._get_control_plane_state_common)

        self.id = self.chassis.get_dpu_id()
        self.name = f'DPU{self.id}'

        self.dpu_state_table = swsscommon.Table(self.chassis_state_db, 'DPU_STATE')

    @staticmethod
    def _pick_state_source(platform_getter, fallback):
        """Return platform_getter unless calling it raises NotImplementedError."""
        try:
            platform_getter()
        except NotImplementedError:
            return fallback
        return platform_getter

    def _get_data_plane_state_common(self):
        """True only if every CONFIG_DB PORT has oper_status 'up' in APPL_DB PORT_TABLE."""
        port_table = swsscommon.Table(self.app_db, 'PORT_TABLE')

        def port_is_up(port_name):
            found, oper = port_table.hget(port_name, 'oper_status')
            return bool(found) and oper.lower() == 'up'

        return all(port_is_up(port) for port in self.config_db.get_table('PORT'))

    def _get_control_plane_state_common(self):
        """True only if STATE_DB SYSTEM_READY|SYSTEM_STATE has Status 'up'."""
        found, sysready = swsscommon.Table(self.state_db, 'SYSTEM_READY').hget('SYSTEM_STATE', 'Status')
        return bool(found) and sysready.lower() == 'up'

    def _time_now(self):
        return get_formatted_time()

    def _write_state(self, state_field, time_field, state):
        """Store a state and then its timestamp under this DPU's entry."""
        self.dpu_state_table.hset(self.name, state_field, state)
        self.dpu_state_table.hset(self.name, time_field, self._time_now())

    def _update_dp_dpu_state(self, state):
        self._write_state(self.DP_STATE, self.DP_UPDATE_TIME, state)

    def _update_cp_dpu_state(self, state):
        self._write_state(self.CP_STATE, self.CP_UPDATE_TIME, state)

    def get_dp_state(self):
        if self._get_dp_state():
            return 'up'
        return 'down'

    def get_cp_state(self):
        if self._get_cp_state():
            return 'up'
        return 'down'

    def update_state(self):
        """Rewrite each plane's state only when it differs from the stored value."""
        planes = (
            (self.get_dp_state, self.DP_STATE, self._update_dp_dpu_state),
            (self.get_cp_state, self.CP_STATE, self._update_cp_dpu_state),
        )
        for read_current, field, write in planes:
            current = read_current()
            _, previous = self.dpu_state_table.hget(self.name, field)
            if current != previous:
                write(current)

    def deinit(self):
        """Mark both planes down on shutdown."""
        self._update_dp_dpu_state('down')
        self._update_cp_dpu_state('down')


#
# Daemon =======================================================================
#


class ChassisdDaemon(daemon_base.DaemonBase):
    """
    chassisd daemon for modular chassis (supervisor / line card) and for the
    switch side of a SmartSwitch. Periodically refreshes module info, midplane
    reachability and down-module cleanup until a fatal signal arrives.
    """

    FATAL_SIGNALS = [signal.SIGINT, signal.SIGTERM]
    NONFATAL_SIGNALS = [signal.SIGHUP]

    def __init__(self, log_identifier, chassis):
        super(ChassisdDaemon, self).__init__(log_identifier)

        # Set by signal_handler() on a fatal signal to end the main loop
        self.stop = threading.Event()

        self.platform_chassis = chassis

        for signum in self.FATAL_SIGNALS + self.NONFATAL_SIGNALS:
            try:
                signal.signal(signum, self.signal_handler)
            except Exception as e:
                self.log_error(f"Cannot register handler for {signum}: {e}")

    # Override signal handler from DaemonBase
    def signal_handler(self, sig, frame):
        """Stop the main loop on SIGINT/SIGTERM; log and ignore other signals."""
        global exit_code

        if sig in self.FATAL_SIGNALS:
            exit_code = 128 + sig  # Make sure we exit with a non-zero code so that supervisor will try to restart us
            self.log_info("Caught {} signal '{}' - exiting...".format(exit_code, SIGNALS_TO_NAMES_DICT[sig]))
            self.stop.set()
        elif sig in self.NONFATAL_SIGNALS:
            self.log_info("Caught signal '{}' - ignoring...".format(SIGNALS_TO_NAMES_DICT[sig]))
        else:
            self.log_warning("Caught unhandled signal '{}' - ignoring...".format(SIGNALS_TO_NAMES_DICT[sig]))

    def submit_dpu_callback(self, module_index, admin_state):
        """Apply admin_state to the module at module_index via the platform API."""
        try_get(self.module_updater.chassis.get_module(module_index).set_admin_state, admin_state, default=False)

    def set_initial_dpu_admin_state(self):
        """
        Bootup pass over all DPUs. For each DPU:
          - initialize its DPU_STATE midplane link state from its oper status;
          - if it has no CHASSIS_MODULE config entry ('empty') and is not
            offline, shut it down (MODULE_ADMIN_DOWN) in a worker thread.
        Waits for all shutdown threads to finish. Errors for one DPU are
        logged and do not stop processing of the others.
        """
        threads = []
        for module_index in range(0, self.module_updater.num_modules):
            op = None
            try:
                module = self.platform_chassis.get_module(module_index)
                module_name = module.get_name()
                operational_state = module.get_oper_status()

                admin_state = self.module_updater.get_module_admin_status(module_name)
                if admin_state == 'empty' and operational_state != ModuleBase.MODULE_STATUS_OFFLINE:
                    # shutdown DPU
                    op = MODULE_ADMIN_DOWN

                # Initialize DPU_STATE DB table on bootup
                dpu_state_key = "DPU_STATE|" + module_name
                op_state = 'up' if operational_state == ModuleBase.MODULE_STATUS_ONLINE else 'down'
                self.module_updater.update_dpu_state(dpu_state_key, op_state)

                if op is not None:
                    thread = threading.Thread(target=self.submit_dpu_callback, args=(module_index, op))
                    thread.daemon = True
                    thread.start()
                    threads.append(thread)

            except Exception as e:
                # sonic-py-common Logger.log_error() takes no exc_info
                # keyword; passing one would raise TypeError here.
                self.log_error(f"Error in run: {str(e)}")

        for thread in threads:
            thread.join()

    # Run daemon
    def run(self):
        """Set up the module updater and config task, then run the main loop."""
        self.log_info("Starting up...")

        self.smartswitch = self.platform_chassis.is_smartswitch()
        self.log_info("smartswitch: {}".format(self.smartswitch))

        if self.smartswitch:
            self.module_updater = SmartSwitchModuleUpdater(SYSLOG_IDENTIFIER, self.platform_chassis)
        else:
            my_slot = try_get(self.platform_chassis.get_my_slot, default=INVALID_SLOT)
            supervisor_slot = try_get(self.platform_chassis.get_supervisor_slot, default=INVALID_SLOT)
            self.module_updater = ModuleUpdater(SYSLOG_IDENTIFIER, self.platform_chassis, my_slot, supervisor_slot)
        self.module_updater.modules_num_update()

        if not self.smartswitch:
            if ((self.module_updater.my_slot == INVALID_SLOT) or
                    (self.module_updater.supervisor_slot == INVALID_SLOT)):
                self.log_error("Chassisd not supported for this platform")
                sys.exit(CHASSIS_NOT_SUPPORTED)

        # Start configuration manager task (SmartSwitch, or the supervisor)
        if self.smartswitch:
            config_manager = SmartSwitchConfigManagerTask()
            config_manager.task_run()
        elif self.module_updater.supervisor_slot == self.module_updater.my_slot:
            config_manager = ConfigManagerTask()
            config_manager.task_run()
        else:
            config_manager = None

        self.log_info("Start daemon main loop")

        # Set the initial DPU admin state for SmartSwitch
        if self.smartswitch:
            self.set_initial_dpu_admin_state()

        while not self.stop.wait(CHASSIS_INFO_UPDATE_PERIOD_SECS):
            self.module_updater.module_db_update()
            self.module_updater.check_midplane_reachability()
            self.module_updater.module_down_chassis_db_cleanup()

        self.log_info("Stop daemon main loop")

        if config_manager is not None:
            config_manager.task_stop()

        # Delete all the information from DB and then exit
        self.module_updater.deinit()

        self.log_info("Shutting down...")


class DpuStateManagerTask(ProcessTaskBase):
    """
    Event-driven DPU state refresher: calls dpu_state_updater.update_state()
    whenever APPL_DB PORT_TABLE or STATE_DB SYSTEM_READY publishes a change.
    """

    def __init__(self, log_identifier, dpu_state_updater):
        super(DpuStateManagerTask, self).__init__()

        self.logger = logger.Logger(log_identifier)
        self.dpu_state_updater = dpu_state_updater
        self.state_db = daemon_base.db_connect('STATE_DB')
        self.app_db = daemon_base.db_connect('APPL_DB')

    def task_worker(self):
        selector = swsscommon.Select()
        subscribers = [
            swsscommon.SubscriberStateTable(self.app_db, 'PORT_TABLE'),
            swsscommon.SubscriberStateTable(self.state_db, 'SYSTEM_READY'),
        ]
        for subscriber in subscribers:
            selector.addSelectable(subscriber)

        try:
            while True:
                state, _ = selector.select(SELECT_TIMEOUT)

                # Timeouts and any other non-OBJECT result just wait again
                if state != swsscommon.Select.OBJECT:
                    continue

                # Drain pending notifications; the state is re-read in full
                for subscriber in subscribers:
                    subscriber.pops()

                self.dpu_state_updater.update_state()

        except KeyboardInterrupt:
            pass


class DpuChassisdDaemon(ChassisdDaemon):
    """
    chassisd running on a SmartSwitch DPU: keeps this DPU's data-plane and
    control-plane state current in CHASSIS_STATE_DB.
    """

    def run(self):
        """
        Poll the platform state getters every CHASSIS_INFO_UPDATE_PERIOD_SECS
        when the platform provides at least one of them. Otherwise fall back
        to a DpuStateManagerTask that refreshes on DB change notifications.
        Marks both planes down on exit.
        """
        self.log_info("Starting up...")

        # With default=None, try_get presumably yields None for an
        # unimplemented getter. Compare against None rather than truthiness:
        # an implemented getter legitimately returns False while the DPU is
        # down, and that must still select polling.
        dp_state = try_get(self.platform_chassis.get_dataplane_state, default=None)
        cp_state = try_get(self.platform_chassis.get_controlplane_state, default=None)
        poll_dpu_state = dp_state is not None or cp_state is not None

        dpu_updater = DpuStateUpdater(SYSLOG_IDENTIFIER, self.platform_chassis)
        dpu_state_mng = None

        if not poll_dpu_state:
            dpu_state_mng = DpuStateManagerTask(SYSLOG_IDENTIFIER, dpu_updater)
            dpu_state_mng.task_run()

        self.log_info("Start daemon main loop")

        while not self.stop.wait(CHASSIS_INFO_UPDATE_PERIOD_SECS):
            if poll_dpu_state:
                dpu_updater.update_state()

        self.log_info("Stop daemon main loop")

        if dpu_state_mng:
            dpu_state_mng.task_stop()

        dpu_updater.deinit()

        self.log_info("Shutting down...")


#
# Main =========================================================================
#


def main():
    """Entry point: start the daemon variant for this platform and exit with exit_code."""
    global exit_code

    chassis = get_chassis()

    # A SmartSwitch DPU runs the DPU-specific daemon; everything else the default one
    runs_on_dpu = chassis.is_smartswitch() and chassis.is_dpu()
    daemon_class = DpuChassisdDaemon if runs_on_dpu else ChassisdDaemon

    daemon_class(SYSLOG_IDENTIFIER, chassis).run()

    sys.exit(exit_code)


if __name__ == '__main__':
    main()
